from .base import AbstractDataloader
from .bert import BertTrainDataset, BertEvalDataset

import torch


class SasDataloader(AbstractDataloader):
    """Dataloader that produces SAS-style train / val / test datasets.

    When ``args.dataloader_output_timestamp`` is truthy, per-user rescaled
    timestamps are computed once at construction time and handed to every
    dataset built afterwards; otherwise ``None`` is handed over instead.

    Attributes such as ``user2dict``, ``rng``, ``train_targets`` and the
    negative-sample tables are read but never assigned in this class --
    presumably they are populated by ``AbstractDataloader.__init__``
    (verify against the base class).
    """

    def __init__(self, args, dataset):
        super().__init__(args, dataset)
        if args.dataloader_output_timestamp:
            self.sas_timestamps = self.calculate_sas_timestamps()
        else:
            self.sas_timestamps = None

    def calculate_sas_timestamps(self):
        """Return a ``{user: [int, ...]}`` mapping of rescaled timestamps.

        For each user the timestamps are shifted by the user's minimum
        timestamp, divided by the smallest nonzero difference between
        consecutive timestamps, rounded, and offset by one.

        Raises:
            ValueError: if a user's timestamp sequence is empty (raised by
                ``min(times)``).
        """
        sas_timestamps = {}  # user -> timestamps
        for user, dic in self.user2dict.items():
            times = dic['timestamps']
            # Smallest nonzero gap between neighbouring timestamps.  When no
            # such gap exists (fewer than two timestamps, or all of them are
            # equal) fall back to a scale of 1.  Using ``default`` instead of
            # a bare ``except`` keeps genuine errors (bad data types,
            # KeyboardInterrupt, ...) from being silently swallowed.
            time_scale = min(
                (y - x for x, y in zip(times[:-1], times[1:]) if x != y),
                default=1,
            )
            min_time = min(times)
            sas_timestamps[user] = [round((t - min_time) / time_scale) + 1 for t in times]  # followed authors' original implementation
        return sas_timestamps

    @classmethod
    def code(cls):
        """Return the string code identifying this dataloader."""
        return 'sas'

    def _get_dataset(self, mode):
        """Return the dataset for ``mode``.

        ``'train'`` yields the training dataset and ``'val'`` the validation
        dataset; any other value yields the test dataset.
        """
        if mode == 'train':
            return self._get_train_dataset()
        elif mode == 'val':
            return self._get_eval_dataset('val')
        else:
            return self._get_eval_dataset('test')

    def _get_train_dataset(self):
        """Build a ``SasTrainDataset`` over ``self.train_targets``."""
        train_ranges = self.train_targets
        dataset = SasTrainDataset(self.args, self.dataset, self.train_negative_samples, self.rng, train_ranges, self.sas_timestamps)
        return dataset

    def _get_eval_dataset(self, mode):
        """Build a ``SasEvalDataset`` for validation (``'val'``) or test.

        Both modes draw their negatives from ``self.test_negative_samples``;
        only the target positions differ.
        """
        positions = self.validation_targets if mode == 'val' else self.test_targets
        dataset = SasEvalDataset(self.args, self.dataset, self.test_negative_samples, positions, self.sas_timestamps)
        return dataset

class SasTrainDataset(BertTrainDataset):
    """Training dataset yielding next-item prediction samples.

    Each sample is a dict of ``LongTensor`` values with keys ``'tokens'``,
    ``'labels'`` and ``'negative_labels'``, plus ``'timestamps'`` and/or
    ``'users'`` when the corresponding output flags are set.

    ``max_len``, ``user2dict``, ``rng``, ``negative_samples``,
    ``train_ranges``, ``index2user_and_offsets`` and ``output_timestamps``
    are read here but not assigned -- presumably the parent constructor sets
    them (verify against ``BertTrainDataset``).
    """

    def __init__(self, args, dataset, negative_samples, rng, train_ranges, sas_timestamps):
        super().__init__(args, dataset, negative_samples, rng, train_ranges)
        # Per-user rescaled timestamps, or None when they are not requested.
        self.timestamps = sas_timestamps
        self.marank_mode = args.model_code in ['marank']
        # Effective window length whenever marank_mode is on.
        self.marank_max_len = args.marank_max_len
        self.output_user = args.dataloader_output_user

        if self.marank_mode:
            # user -> position, used as the upper bound when sampling offsets.
            self.user2pos = {u: p for u, p in self.train_ranges}

    def __getitem__(self, index):
        """Return the training sample at ``index`` as a dict of tensors."""
        user, offset = self.index2user_and_offsets[index]
        if self.marank_mode:
            # The stored offset was generated from max_len / train_window so
            # that the number of samples matches the other models; in marank
            # mode it is discarded and a fresh one is drawn instead.  The
            # drawn offset is used as an exclusive slice end below.
            offset = self.rng.randint(2, self.user2pos[user])
            window = self.marank_max_len
        else:
            window = self.max_len

        start = max(0, offset - window - 1)
        stop = offset  # the item at `offset` itself is left out on purpose
        history = self.user2dict[user]['items'][start:stop]

        inputs = history[:-1]
        num_pad = window - len(inputs)

        if self.marank_mode:
            # Single target: the last item of the window, with one negative.
            targets = [history[-1]]
            pool = self.negative_samples[user]
            negatives = [self.rng.choice(pool)]

            # Right-pad by repeating the final input token.
            inputs = inputs + [inputs[-1]] * num_pad
        else:
            # One target per input position (the sequence shifted by one),
            # each paired with a negative drawn from the user's pool.
            targets = history[1:]
            pool = self.negative_samples[user]
            negatives = [self.rng.choice(pool) for _ in targets]

            # Left-pad everything with zeros up to the window length.
            zeros = [0] * num_pad
            inputs = zeros + inputs
            targets = zeros + targets
            negatives = zeros + negatives

        sample = {
            'tokens': torch.LongTensor(inputs),
            'labels': torch.LongTensor(targets),
            'negative_labels': torch.LongTensor(negatives),
        }
        if self.output_timestamps:
            # Timestamps cover the same positions as the input tokens.
            # NOTE(review): they are left-padded with zeros in both modes,
            # whereas marank tokens are right-padded -- confirm intended.
            stamps = self.timestamps[user][start:stop - 1]
            sample['timestamps'] = torch.LongTensor([0] * num_pad + stamps)
        if self.output_user:
            sample['users'] = torch.LongTensor([user])
        return sample


class SasEvalDataset(BertEvalDataset):
    """Evaluation dataset yielding a history plus ranked candidates.

    Each sample is a dict of ``LongTensor`` values with keys ``'tokens'``,
    ``'candidates'`` (the answer item followed by the user's negatives) and
    ``'labels'`` (1 for the answer, 0 for every negative), plus
    ``'timestamps'`` and/or ``'users'`` when the corresponding output flags
    are set.

    ``max_len``, ``user2dict``, ``positions``, ``negative_samples`` and
    ``output_timestamps`` are read here but not assigned -- presumably the
    parent constructor sets them (verify against ``BertEvalDataset``).
    """

    def __init__(self, args, dataset, negative_samples, positions, sas_timestamps):
        super().__init__(args, dataset, negative_samples, positions)
        # Per-user rescaled timestamps, or None when they are not requested.
        self.timestamps = sas_timestamps
        self.output_user = args.dataloader_output_user
        self.marank_mode = args.model_code in ['marank']
        self.marank_max_len = args.marank_max_len

    def __getitem__(self, index):
        """Return the evaluation sample at ``index`` as a dict of tensors."""
        user, target_idx = self.positions[index]
        items = self.user2dict[user]['items']
        if self.marank_mode:
            window = self.marank_max_len
        else:
            window = self.max_len

        # Unlike the BERT variant, which includes the answer item and masks
        # it, the SAS input excludes the answer item: the slice stops right
        # before the target position and no [MASK] token is inserted.
        start = max(0, target_idx - window)
        stop = target_idx
        answer = [items[target_idx]]
        history = items[start:stop]

        negs = self.negative_samples[user]
        candidates = answer + negs
        labels = [1] * len(answer) + [0] * len(negs)

        num_pad = window - len(history)
        if self.marank_mode:
            # Right-pad by repeating the last history item.
            history = history + [history[-1]] * num_pad
        else:
            # Left-pad with zeros.
            history = [0] * num_pad + history

        sample = {
            'tokens': torch.LongTensor(history),
            'candidates': torch.LongTensor(candidates),
            'labels': torch.LongTensor(labels),
        }
        if self.output_timestamps:
            # NOTE(review): timestamps are left-padded with zeros in both
            # modes, whereas marank tokens are right-padded -- confirm.
            stamps = self.timestamps[user][start:stop]
            sample['timestamps'] = torch.LongTensor([0] * num_pad + stamps)
        if self.output_user:
            sample['users'] = torch.LongTensor([user])
        return sample